package org.zjvis.datascience.spark.algorithm;

import org.apache.commons.lang3.StringUtils;
import org.apache.spark.Accumulator;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.fpm.PrefixSpan;
import org.apache.spark.mllib.fpm.PrefixSpanModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.storage.StorageLevel;
import org.zjvis.datascience.spark.util.OptionHelper;
import org.zjvis.datascience.spark.util.OutputResult;
import org.zjvis.datascience.spark.util.UtilTool;
import scala.Tuple3;
import scala.collection.JavaConversions;

import java.util.*;

/**
 * @description Spark-Prefix Span算子
 * @date 2021-12-23
 */
public class PrefixSpanAlgorithm extends BaseAlgorithm {

    private String[] featureCols;

    private String sourceTable;

    private String targetTable;

    private double minSupport;

    private int minFreq;

    public boolean parseParams(String[] args) {
        super.parseParams(args);

        options.addOption(OptionHelper.getSpecOption("source"));
        options.addOption(OptionHelper.getSpecOption("target"));
        options.addOption(OptionHelper.getSpecOption("featureCols"));
        options.addOption(OptionHelper.getSpecOption("mins", "minSupport", "min support", true));
        options.addOption(OptionHelper.getSpecOption("minf", "minFreq", "min frequency", true));

        boolean flag = this.initCommandLine(args);

        if (!flag) {
            logger.error("initCommandLine fail!!!");
            return false;
        }
        if (cmd.hasOption("f") || cmd.hasOption("featureCols")) {
            featureCols = cmd.getOptionValue("featureCols").split(",");
        }
        if (cmd.hasOption("s") || cmd.hasOption("source")) {
            sourceTable = cmd.getOptionValue("source");
        }
        if (cmd.hasOption("t") || cmd.hasOption("target")) {
            targetTable = cmd.getOptionValue("target");
        }
        if (cmd.hasOption("mins") || cmd.hasOption("minSupport")) {
            minSupport = Double.parseDouble(cmd.getOptionValue("minSupport"));
        }
        if (cmd.hasOption("minf") || cmd.hasOption("minFreq")) {
            minFreq = Integer.parseInt(cmd.getOptionValue("minFreq"));
        }

        return true;
    }

    public PrefixSpanAlgorithm(SparkSession sparkSession) {
        super(sparkSession);
    }

    public boolean beginAlgorithm() {
        Dataset<Row> dataset;
        if (isHive) {
            String sql = UtilTool.buildSelectSql(featureCols, sourceTable, "");
            dataset = sparkSession.sql(sql);
        } else {
            if (isSample) {
                String cacheTable = cacheUtil.modifyCacheTableName(sourceTable);
                if (cacheUtil.isCacheTableExists(cacheTable)) {
                    dataset = sparkSession.table(cacheTable)
                            .select(JavaConversions.asScalaBuffer(UtilTool.selectColumns(featureCols, "")));
                } else {
                    Dataset<Row> cacheRdd = UtilTool.readFromGreenPlum(sparkSession, sourceTable, "", sampleNumber);
                    dataset = cacheRdd.select(JavaConversions.asScalaBuffer(UtilTool.selectColumns(featureCols, "")));
                    if (!cacheUtil.cacheTableForDataset(cacheRdd, cacheTable)) {
                        logger.error("table = {} cache fail!!!", cacheTable);
                        return false;
                    }
                    logger.info("table = {} cache success", cacheTable);
                }
            } else {
                dataset = UtilTool.readFromGreenPlum(sparkSession, sourceTable, "")
                        .select(JavaConversions.asScalaBuffer(UtilTool.selectColumns(featureCols, "")));
            }
        }

        int length = featureCols.length;
        JavaRDD<List<List<String>>> tansactions = dataset.toJavaRDD().map(row-> {
            List<List<String>> lists = new ArrayList<>();
            for (int i = 0; i < length; ++i) {
                String s = row.get(i) == null ? "" : row.get(i).toString();
                if (StringUtils.isNotEmpty(s)) {
                    lists.add(Arrays.asList(s.split(",")));
                }
            }
            return lists;
        }).persist(StorageLevel.MEMORY_AND_DISK());

        PrefixSpan prefixSpan = new PrefixSpan()
                .setMinSupport(minSupport);
        PrefixSpanModel<String> model = prefixSpan.run(tansactions);

        long totalCount = tansactions.count();
//        Accumulator<Integer> rowIndex =  sparkSession.sparkContext().accumulator(0, "rowIndex", null);
        JavaRDD<Row> resultRdd = model.freqSequences().toJavaRDD()
                .mapPartitions(iter -> {
                    ArrayList<Tuple3<String, Double, Long>> result = new ArrayList<>();
                    while (iter.hasNext()) {
                        PrefixSpan.FreqSequence<String> freqSequence = iter.next();
                        if (freqSequence.freq() >= minFreq) {
                            result.add(new Tuple3<>(freqSequence.javaSequence().toString(), freqSequence.freq() * 1.0 / totalCount, freqSequence.freq()));
                        }
                    }
                    return result.iterator();
                }).zipWithUniqueId()
                .map(x -> RowFactory.create(Integer.parseInt(String.valueOf(x._2)), x._1._1(), x._1._2(), x._1._3()));
//        JavaRDD<Row> resultRdd = model.freqSequences().toJavaRDD()
//               .map(row -> RowFactory.create(row.javaSequence().toString(), row.freq() * 1.0 / totalCount, row.freq()))
//               .filter(x->x.getLong(2) >= minFreq);

        StructType structType = DataTypes.createStructType(new StructField[]{
                DataTypes.createStructField("_record_id_", DataTypes.IntegerType, true),
                DataTypes.createStructField("pattern", DataTypes.StringType, true),
                DataTypes.createStructField("support", DataTypes.DoubleType, true),
                DataTypes.createStructField("frequency", DataTypes.LongType, true)
        });

        Dataset<Row> df = sparkSession.createDataFrame(resultRdd, structType);

        if (isHive) {
            this.saveHiveTable(df, targetTable);
        } else {
            UtilTool.saveGreenplumTable(df, targetTable);
        }

        Map<String, Object> inputKV = new HashMap<>();
        List<String> outputTables = new ArrayList<>();
        outputTables.add(targetTable);

        String metaPath = String.format(UtilTool.RUN_META_PATH, algorithmName, uniqueKey);

        OutputResult outputResult = this.buildOutputResult(outputTables, 0, "", inputKV);

        if (isHive) {
            this.saveOutputResultForHive(sparkSession, outputResult, metaPath);
        } else {
            this.saveOutputResultForMysql(sparkSession, outputResult);
        }
        retResult = outputResult;
        tansactions.unpersist();
        return true;
    }
}
